Support async functions in map()#7384

Merged
lhoestq merged 10 commits into main from support-async-functions-in-map on Feb 13, 2025

lhoestq (Member) commented Feb 3, 2025

Adds support for async functions in map(), e.g. to download images or call an inference API like HF or vLLM.

import asyncio
import random

from datasets import Dataset


async def f(x):
    # Simulate an I/O-bound call that takes up to one second
    await asyncio.sleep(random.random())

ds = Dataset.from_dict({"data": range(100)})
ds.map(f)
# Map: 100%|█████████████████████████████| 100/100 [00:01<00:00, 99.81 examples/s]
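The progress bar above shows 100 examples finishing in about one second even though each call sleeps for up to a second, which suggests the calls overlap rather than run one after another. A minimal standard-library sketch of that effect, assuming nothing about how `datasets` schedules the coroutines internally (`run_concurrently` is a hypothetical helper, not part of the library):

```python
import asyncio
import time


async def f(x):
    # Stand-in for an I/O-bound call (download, inference request, ...).
    await asyncio.sleep(0.05)
    return {"data": x}


async def run_concurrently(examples):
    # Schedule one coroutine per example and await them together,
    # so the waits overlap instead of adding up.
    return await asyncio.gather(*(f(x) for x in examples))


start = time.perf_counter()
results = asyncio.run(run_concurrently(range(20)))
elapsed = time.perf_counter() - start
print(f"{len(results)} examples in {elapsed:.2f}s")
```

Run sequentially, the 20 sleeps would add up to about one second; awaited together they take roughly as long as a single sleep.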

TODO

  • clean code (right now it's a big copy paste)
  • batched
  • Dataset.map()
  • IterableDataset.map()
  • Dataset.filter()
  • IterableDataset.filter()
  • test
  • docs
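For the batched and filter items, the function shapes would presumably follow the usual `datasets` conventions: a batched function receives a dict of column lists, and a filter function returns a boolean. The sketch below only illustrates those shapes with plain `asyncio`; the function names are hypothetical, and the list above does not say which items the merged PR ended up covering.

```python
import asyncio


async def add_length_batched(batch):
    # Batched convention: `batch` maps column names to lists of values,
    # and the function returns new columns as lists of the same length.
    await asyncio.sleep(0)  # placeholder for real async work
    return {"length": [len(text) for text in batch["text"]]}


async def is_short(example):
    # Filter convention: return True to keep the example.
    await asyncio.sleep(0)  # placeholder for real async work
    return len(example["text"]) < 4


async def keep_flags(batch):
    # Evaluate the async predicate on every row of the batch concurrently.
    return await asyncio.gather(*(is_short({"text": t}) for t in batch["text"]))


batch = {"text": ["a", "bcd", "efghi"]}
new_columns = asyncio.run(add_length_batched(batch))
flags = asyncio.run(keep_flags(batch))
print(new_columns, flags)
```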

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lhoestq lhoestq marked this pull request as ready for review February 12, 2025 15:31
lhoestq (Member, Author) commented Feb 12, 2025

Example of what you can do with it:

import aiohttp
from huggingface_hub import get_token

from datasets import Dataset


API_URL = "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct/v1/chat/completions"
PROMPT = "What is this text mainly about ? Here is the text:\n\n```\n{Problem}\n```\n\nReply in one or two words."

async def query(example):
    headers = {"Authorization": f"Bearer {get_token()}", "Content-Type": "application/json"}
    # Named `payload` to avoid shadowing the standard `json` module name
    payload = {"messages": [{"role": "user", "content": PROMPT.format(Problem=example["Problem"])}], "max_tokens": 20, "seed": 42}
    # One session per call keeps the example short; a shared session would reuse connections
    async with aiohttp.ClientSession() as session, session.post(API_URL, headers=headers, json=payload) as response:
        output = await response.json()
        return {"output": output["choices"][0]["message"]["content"]}

ds = Dataset.from_dict({"Problem": ["1 + 1"] * 10})
ds = ds.map(query)
print(ds[0])
# {'Problem': '1 + 1', 'output': 'Arithmetic'}
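When pointing an async function at a real API, it can help to cap how many requests are in flight at once. The sketch below shows one way to do that inside the mapped function with an `asyncio.Semaphore`, and how the returned dict ends up merged into each row, as the printed `ds[0]` suggests. It uses only the standard library; `fake_api_call`, `map_rows`, and `MAX_CONCURRENT_REQUESTS` are hypothetical names, and the PR does not say whether or how `datasets` itself limits concurrency.

```python
import asyncio

MAX_CONCURRENT_REQUESTS = 4  # hypothetical limit, not a datasets setting


async def fake_api_call(prompt):
    # Hypothetical stand-in for the HTTP request in the example above.
    await asyncio.sleep(0.01)
    return f"echo: {prompt}"


async def query(example, semaphore):
    # The semaphore keeps at most MAX_CONCURRENT_REQUESTS calls in flight.
    async with semaphore:
        answer = await fake_api_call(example["Problem"])
    return {"output": answer}


async def map_rows(rows):
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    updates = await asyncio.gather(*(query(row, semaphore) for row in rows))
    # Merge each returned dict into its row to form the mapped rows.
    return [{**row, **update} for row, update in zip(rows, updates)]


rows = [{"Problem": "1 + 1"} for _ in range(10)]
mapped = asyncio.run(map_rows(rows))
print(mapped[0])
```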

@lhoestq lhoestq merged commit 339e9dc into main Feb 13, 2025
12 of 15 checks passed
@lhoestq lhoestq deleted the support-async-functions-in-map branch February 13, 2025 14:00